{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
    "def train():\n",
    "    import numpy as np\n",
    "    import tensorflow as tf\n",
    "    import wandb\n",
    "    config_defaults = {\n",
    "        'hidden_nodes': 128\n",
    "    }\n",
    "    wandb.init(config=config_defaults)\n",
    "    \n",
    "    fashion_mnist = tf.keras.datasets.fashion_mnist\n",
    "    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()\n",
    "    class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
    "                   'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n",
    "\n",
    "    train_images = train_images / 255.0\n",
    "    test_images = test_images / 255.0\n",
    "\n",
    "    model = tf.keras.Sequential([\n",
    "        tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
    "        tf.keras.layers.Dense(wandb.config.hidden_nodes, activation=tf.nn.relu),\n",
    "        tf.keras.layers.Dense(10, activation=tf.nn.softmax)\n",
    "    ])\n",
    "\n",
    "    model.compile(optimizer='adam',\n",
    "                  loss='sparse_categorical_crossentropy',\n",
    "                  metrics=['accuracy'])\n",
    "    \n",
    "    model.fit(train_images, train_labels, epochs=5, callbacks=[wandb.keras.WandbCallback(input_type=\"images\", save_model=False)],\n",
    "                  validation_data=(test_images, test_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sweep_config = {\n",
    "    'method': 'grid',\n",
    "    'parameters': {\n",
    "        'hidden_nodes': {\n",
    "            'values': [32, 64, 96, 128, 256]\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Create sweep with ID: 3qwdp8z9\n",
      "Sweep URL: https://app.wandb.ai/qualcomm/sweeps-sep26/sweeps/3qwdp8z9\n"
     ]
    }
   ],
   "source": [
    "import wandb\n",
    "sweep_id = wandb.sweep(sweep_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "wandb: Agent Starting Run: fwjp479r with config:\n",
      "\thidden_nodes: 32\n",
      "wandb: Agent Started Run: fwjp479r\n"
     ]
    }
   ],
   "source": [
    "wandb.agent(sweep_id, function=train)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
